{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save intermediate feature-maps to file\n",
    "\n",
    "This short notebook performs a forward pass on a single ImageNet image and saves the intermediate feature-maps to a file for later inspection."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "\n",
    "# Relative import of code from distiller, w/o installing the package\n",
    "import os\n",
    "import sys\n",
    "import numpy as np \n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "\n",
    "# Put the parent directory on sys.path so that `import distiller` below can resolve to the\n",
    "# source tree one level up, without the package being installed.\n",
    "module_path = os.path.abspath(os.path.join('..'))\n",
    "if module_path not in sys.path:\n",
    "    sys.path.append(module_path)\n",
    "\n",
    "import distiller\n",
    "import distiller.apputils as apputils\n",
    "import distiller.models as models\n",
    "from distiller.apputils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dense baseline: pretrained ImageNet ResNet50\n",
    "resnet50 = models.create_model(pretrained=True, dataset='imagenet', arch='resnet50', parallel=False)\n",
    "\n",
    "# \"Sparse\" compressed model: build the same architecture, then overwrite it from the checkpoint.\n",
    "# NOTE(review): it is created with parallel=True, presumably because the checkpoint was saved\n",
    "# from a data-parallel model - confirm. It is converted to a non-parallel copy right after loading.\n",
    "sparse_checkpoint = \"../examples/classifier_compression/resnet50_checkpoint_70.66-sparse_76.09-top1.pth.tar\"\n",
    "resnet50_sparse = models.create_model(pretrained=True, dataset='imagenet', arch='resnet50', parallel=True)\n",
    "apputils.load_checkpoint(resnet50_sparse, sparse_checkpoint)\n",
    "resnet50_sparse = distiller.make_non_parallel_copy(resnet50_sparse)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision import transforms\n",
    "\n",
    "# Per-channel mean/std used to normalize the input image\n",
    "normalize = transforms.Normalize(\n",
    "    mean=[0.485, 0.456, 0.406],\n",
    "    std=[0.229, 0.224, 0.225]\n",
    ")\n",
    "# Resize -> center-crop to 224x224 -> tensor -> normalize\n",
    "preprocess = transforms.Compose([\n",
    "    transforms.Resize(256),\n",
    "    transforms.CenterCrop(224),\n",
    "    transforms.ToTensor(),\n",
    "    normalize\n",
    "])\n",
    "\n",
    "# Open one validation image and do a forward pass, as a sanity check that the model works.\n",
    "with open(\"../../data.imagenet/val/n02669723/ILSVRC2012_val_00029372.JPEG\", 'rb') as f:\n",
    "    # convert('RGB') reads the pixel data while the file is still open and guarantees the\n",
    "    # three channels that `normalize` expects (some JPEGs are grayscale or CMYK).\n",
    "    img = Image.open(f).convert('RGB')\n",
    "img = preprocess(img)\n",
    "img = img.unsqueeze_(0)  # add the batch dimension\n",
    "\n",
    "# `sample` is reused by the feature-map extraction cell further down.\n",
    "# No Variable() wrapper: tensors and Variables are merged since PyTorch 0.4.\n",
    "sample = img\n",
    "\n",
    "resnet50_sparse.eval()\n",
    "# Inference only: do not record an autograd graph for this forward pass\n",
    "with torch.no_grad():\n",
    "    predictions = resnet50_sparse(sample.cuda())\n",
    "\n",
    "top1_vals, top1_indices = torch.max(predictions, 1)\n",
    "print(top1_vals.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# This is where the work gets done:\n",
    "# We register for forward callbacks, perform a forward pass, and in the callbacks we cache the\n",
    "# intermediate feature-maps.\n",
    "\n",
    "# Feature-maps captured by the forward hooks, keyed by each module's `distiller_name`\n",
    "intermediary_results = dict()\n",
    "# Handles of the registered hooks, kept so that the hooks can be removed afterwards\n",
    "cb_handles = list()\n",
    "\n",
    "def save_conv_output(m, i, o):\n",
    "    \"\"\"Forward hook: cache the output of module `m` as a NumPy array.\n",
    "\n",
    "    Args:\n",
    "        m: the module that produced the output; its `distiller_name` is the cache key.\n",
    "        i: the module's input (unused).\n",
    "        o: the module's output; it is moved to the CPU before being converted.\n",
    "    \"\"\"\n",
    "    feature_map = o.data.cpu().numpy()\n",
    "    intermediary_results[m.distiller_name] = feature_map\n",
    "    \n",
    "def register_forward_hook(m):\n",
    "    \"\"\"Attach `save_conv_output` to `m` if the module carries a `distiller_name`.\n",
    "\n",
    "    Modules without that attribute are skipped. The handle returned by the registration\n",
    "    is appended to `cb_handles` so that the hook can be removed later.\n",
    "    \"\"\"\n",
    "    if not hasattr(m, 'distiller_name'):\n",
    "        return\n",
    "    handle = m.register_forward_hook(save_conv_output)\n",
    "    cb_handles.append(handle)\n",
    "\n",
    "# NOTE(review): assign_layer_fq_names is expected to set the `distiller_name` attribute that\n",
    "# register_forward_hook checks for - confirm against the distiller sources.\n",
    "distiller.utils.assign_layer_fq_names(resnet50_sparse)\n",
    "try:\n",
    "    resnet50_sparse.apply(register_forward_hook)\n",
    "    # Only the activations are needed, so do not record an autograd graph\n",
    "    with torch.no_grad():\n",
    "        predictions = resnet50_sparse(sample.cuda())\n",
    "finally:\n",
    "    # Always detach the hooks, even if the forward pass fails; otherwise they would keep\n",
    "    # firing (and caching feature-maps) on every later use of the model.\n",
    "    for h in cb_handles:\n",
    "        h.remove()\n",
    "print(len(intermediary_results))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Now save the feature-maps to file...\n",
    "import pickle\n",
    "\n",
    "# The whole {layer name: feature-map} dictionary is serialized in a single dump\n",
    "with open('resnet50_sparse_fms.pickle', 'wb') as fms_file:\n",
    "    pickle.dump(intermediary_results, fms_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# This is how you load results later...\n",
    "# NOTE: pickle.load can execute arbitrary code - only load pickle files that you created yourself.\n",
    "\n",
    "with open('resnet50_sparse_fms.pickle', 'rb') as f:\n",
    "    intermediary_results_from_file = pickle.load(f)\n",
    "\n",
    "# Print every layer name together with the shape of its saved feature-map.\n",
    "# (`.shape` gives the dimensions; on a NumPy array `.size` is only the total element count.)\n",
    "for name, feature_map in intermediary_results_from_file.items():\n",
    "    print(name, feature_map.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
